package edu.northwestern.cbits.purple_robot_manager.models;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import com.alexmerz.graphviz.ParseException;
import com.alexmerz.graphviz.Parser;
import com.alexmerz.graphviz.objects.Edge;
import com.alexmerz.graphviz.objects.Graph;
import com.alexmerz.graphviz.objects.Id;
import com.alexmerz.graphviz.objects.Node;
import edu.northwestern.cbits.purple_robot_manager.R;
import edu.northwestern.cbits.purple_robot_manager.logging.LogManager;
import android.content.Context;
import android.net.Uri;
import android.util.Log;
/**
* Implements a trained decision tree model encoded in the GraphViz (DOT)
* format generated by Weka and parsed with the GraphViz Java library:
*
* http://www.alexander-merz.com/graphviz/
*
* Note that this class does not train the model, but instead expects the
* representation of a model already trained.
*
* A sample tree representation (spaces added for readability):
*
* <pre>
* {@code
* digraph J48Tree {
* N0 [label="telephonyprobe_psc" ]
* N0->N1 [label="<= 305"]
* N1 [label="weatherundergroundfeature_visibility" ]
* N1->N2 [label="<= 8"]
* N2 [label="colleagues (2.0/1.0)" shape=box style=filled ]
* N1->N3 [label="> 8"]
* N3 [label="wifiaccesspointsprobe_access_point_count" ]
* N3->N4 [label="<= 11"]
* N4 [label="sunrisesunsetfeature_sunrise_distance" ]
* N4->N5 [label="<= 28478000"]
* N5 [label="family (7.0)" shape=box style=filled ]
* N4->N6 [label="> 28478000"]
* N6 [label="weatherundergroundfeature_temperature" ]
* N6->N7 [label="<= 25.3"]
* N7 [label="locationprobe_bearing" ]
* N7->N8 [label="<= 91.699997"]
* N8 [label="friends (3.0/1.0)" shape=box style=filled ]
* N7->N9 [label="> 91.699997"]
* N9 [label="alone (11.0/2.0)" shape=box style=filled ]
* N6->N10 [label="> 25.3"]
* N10 [label="colleagues (2.0)" shape=box style=filled ]
* N3->N11 [label="> 11"]
* N11 [label="wifiaccesspointsprobe_current_rssi" ]
* N11->N12 [label="<= -64"]
* N12 [label="friends (6.0/1.0)" shape=box style=filled ]
* N11->N13 [label="> -64"]
* N13 [label="family (6.0)" shape=box style=filled ]
* N0->N14 [label="> 305"]
* N14 [label="other (4.0/1.0)" shape=box style=filled ]}}
* </pre>
*/
public class WekaTreeModel extends TrainedModel
{
    /** Model type identifier returned by {@link #modelType()}. */
    public static final String TYPE = "weka-decision-tree";

    /** Log tag used for diagnostics emitted while walking the tree. */
    private static final String LOG_TAG = "PR";

    /** Identifier of the node where evaluation starts (Weka names the root N0). */
    private static final String ROOT_NODE_ID = "N0";

    /**
     * Parsed decision tree. Null until {@link #generateModel(Context, Object)}
     * successfully parses a model; while null, evaluation yields null.
     */
    private Graph _tree = null;

    /**
     * Creates the model; construction is delegated to
     * {@code TrainedModel(Context, Uri)}.
     *
     * @param context Android context passed to the superclass.
     * @param uri     Model URI passed to the superclass.
     */
    public WekaTreeModel(Context context, Uri uri)
    {
        super(context, uri);
    }

    /**
     * Parses the GraphViz representation provided in {@code model} and keeps
     * the first graph it contains as the decision tree used for predictions.
     * A null model, an unsuccessful parse, or a parse yielding no graphs leaves
     * the currently held tree (if any) unchanged. Parse exceptions are logged
     * through {@link LogManager} rather than propagated.
     *
     * @param context Context used to obtain the {@link LogManager}.
     * @param model   Object whose {@code toString()} yields the DOT source.
     *
     * @see http://www.alexander-merz.com/graphviz/doc/com/alexmerz/graphviz/objects/Graph.html
     * @see edu.northwestern.cbits.purple_robot_manager.models.TrainedModel#generateModel(android.content.Context,
     *      java.lang.Object)
     */
    protected void generateModel(Context context, Object model)
    {
        if (model == null)
            return;

        StringReader reader = new StringReader(model.toString());
        Parser parser = new Parser();

        try
        {
            if (parser.parse(reader))
            {
                ArrayList<Graph> graphs = parser.getGraphs();

                if (graphs != null && graphs.isEmpty() == false)
                    this._tree = graphs.get(0);
            }
        }
        catch (ParseException e)
        {
            LogManager.getInstance(context).logException(e);
        }
    }

    /**
     * Locates the root node ({@code N0}) of the parsed tree and evaluates the
     * snapshot against it.
     *
     * @param context  Unused by this implementation.
     * @param snapshot Feature values keyed by name.
     *
     * @return The predicted class label, or null when no tree has been parsed,
     *         the snapshot is null, the root node cannot be found, or the walk
     *         cannot reach a leaf.
     *
     * @see edu.northwestern.cbits.purple_robot_manager.models.TrainedModel#evaluateModel(android.content.Context,
     *      java.util.Map)
     */
    protected Object evaluateModel(Context context, Map<String, Object> snapshot)
    {
        // Read the field once so the null check and the use see the same tree.
        Graph tree = this._tree;

        if (tree == null || snapshot == null)
            return null;

        Id rootId = new Id();
        rootId.setId(WekaTreeModel.ROOT_NODE_ID);

        Node root = tree.findNode(rootId);

        if (root == null)
            return null;

        return this.fetchPrediction(root, tree.getEdges(), snapshot);
    }

    /**
     * Evaluates the state of the world in relation to the provided node. Based
     * on the value of the node, recursively passes control to the next node in
     * the evaluation sequence until encountering a leaf node containing the
     * prediction. The prediction value is returned up the tree and becomes the
     * final prediction for the model. Returns null if an error is encountered
     * or data needed to evaluate the model is missing.
     *
     * @param node
     *            Node object representing the current location in the tree.
     * @param edges
     *            Edges of the graph connecting nodes. Encodes comparison
     *            operators.
     * @param snapshot
     *            States used to generate prediction.
     *
     * @return Prediction given states, or null.
     */
    protected String fetchPrediction(Node node, List<Edge> edges, Map<String, Object> snapshot)
    {
        synchronized (this)
        {
            if (node == null || edges == null || snapshot == null)
                return null;

            String label = node.getAttribute("label");

            if (label == null)
                return null;

            // Outgoing edges identify this node's children; a node without any is a leaf.
            List<Edge> childEdges = new ArrayList<>();

            for (Edge edge : edges)
            {
                if (edge.getSource().getNode() == node)
                    childEdges.add(edge);
            }

            if (childEdges.isEmpty())
                return WekaTreeModel.leafPrediction(label);

            // The tested attribute is the last space-separated token of the label.
            // Attribute names and snapshot keys are compared with underscores
            // removed and case ignored.
            String[] tokens = label.replace("_", "").split(" ");

            if (tokens.length == 0)
                return null;

            String attribute = tokens[tokens.length - 1];

            for (Map.Entry<String, Object> entry : snapshot.entrySet())
            {
                String key = entry.getKey();

                if (key == null || key.replace("_", "").equalsIgnoreCase(attribute) == false)
                    continue;

                Node nextNode = WekaTreeModel.selectChild(childEdges, entry.getValue());

                if (nextNode != null)
                    return this.fetchPrediction(nextNode, edges, snapshot);
            }

            return null;
        }
    }

    /**
     * Extracts the class label from a leaf's label: the text after the first
     * colon (if any), cut at the first space, so {@code "family (7.0)"}
     * yields {@code "family"}.
     */
    private static String leafPrediction(String label)
    {
        String prediction = label.substring(label.indexOf(":") + 1).trim();

        int space = prediction.indexOf(" ");

        if (space != -1)
            prediction = prediction.substring(0, space).trim();

        return prediction;
    }

    /**
     * Chooses the child of a decision node whose edge condition is satisfied
     * by {@code value}. Relational edges ({@code <=}, {@code <}, {@code >=},
     * {@code >}) apply only to numeric values; {@code =} edges compare nominal
     * values; an {@code = ?} edge is taken only when no other edge matches
     * (including when {@code value} is null). If several other edges match,
     * the last one wins.
     *
     * @return Target node of the selected edge, or null if no edge applies.
     */
    private static Node selectChild(List<Edge> childEdges, Object value)
    {
        double numericValue = WekaTreeModel.toDouble(value);
        Node selected = null;

        for (Edge edge : childEdges)
        {
            String edgeLabel = edge.getAttribute("label");

            if (edgeLabel == null)
                continue;

            edgeLabel = edgeLabel.trim();

            // Edge labels have the form "<operator> <operand>".
            int space = edgeLabel.indexOf(" ");

            if (space == -1)
            {
                Log.e(WekaTreeModel.LOG_TAG, "UNKNOWN OP: -" + edgeLabel + "-");
                continue;
            }

            String operation = edgeLabel.substring(0, space);
            String edgeValue = edgeLabel.substring(space + 1).trim();

            switch (operation)
            {
                case "=":
                    if ("?".equals(edgeValue))
                    {
                        // Missing-value branch: only a fallback.
                        if (selected == null)
                            selected = edge.getTarget().getNode();
                    }
                    else if (WekaTreeModel.matchesNominal(value, numericValue, edgeValue))
                        selected = edge.getTarget().getNode();

                    break;
                case "<=":
                case "<":
                case ">=":
                case ">":
                    if (Double.isNaN(numericValue) == false
                            && WekaTreeModel.satisfies(operation, numericValue, edgeValue))
                        selected = edge.getTarget().getNode();

                    break;
                default:
                    Log.e(WekaTreeModel.LOG_TAG, "UNKNOWN OP: -" + operation + "-");
            }
        }

        return selected;
    }

    /**
     * Applies a relational operator to a numeric value and the edge's
     * threshold. An unparseable threshold is logged and treated as no match.
     */
    private static boolean satisfies(String operation, double value, String edgeValue)
    {
        double threshold;

        try
        {
            threshold = Double.parseDouble(edgeValue);
        }
        catch (NumberFormatException e)
        {
            Log.e(WekaTreeModel.LOG_TAG, "UNPARSEABLE THRESHOLD: -" + edgeValue + "-");
            return false;
        }

        switch (operation)
        {
            case "<=":
                return value <= threshold;
            case "<":
                return value < threshold;
            case ">=":
                return value >= threshold;
            case ">":
                return value > threshold;
            default:
                return false;
        }
    }

    /**
     * Tests whether a value matches the operand of an {@code =} edge. Numeric
     * values are compared numerically when the operand parses as a number;
     * otherwise the value's text (periods removed) is compared to the operand
     * ignoring case.
     */
    private static boolean matchesNominal(Object value, double numericValue, String edgeValue)
    {
        if (value == null)
            return false;

        if (Double.isNaN(numericValue) == false)
        {
            try
            {
                return Double.parseDouble(edgeValue) == numericValue;
            }
            catch (NumberFormatException ignored)
            {
                // Non-numeric operand: fall through to the textual comparison.
            }
        }

        // TODO: Normalize. Periods are stripped from the value (not from the
        // edge operand) before the case-insensitive comparison.
        String valueString = value.toString().replace(".", "");

        return valueString.equalsIgnoreCase(edgeValue);
    }

    /**
     * Converts any {@link Number} to a double; every other value (including
     * null) yields {@link Double#NaN}, marking it as non-numeric.
     */
    private static double toDouble(Object value)
    {
        if (value instanceof Number)
            return ((Number) value).doubleValue();

        return Double.NaN;
    }

    /**
     * @return {@link #TYPE}.
     */
    public String modelType()
    {
        return WekaTreeModel.TYPE;
    }

    /**
     * @param context Context used to resolve the string resource.
     *
     * @return The localized {@code R.string.summary_model_tree} string.
     */
    public String summary(Context context)
    {
        return context.getString(R.string.summary_model_tree);
    }
}